
import math
from typing import Optional, List

import torch

from megatron.core import parallel_state
import megatron.core.transformer.moe.moe_utils
from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region

get_capacity = megatron.core.transformer.moe.moe_utils.get_capacity
group_limited_topk = megatron.core.transformer.moe.moe_utils.group_limited_topk


def sequence_load_balancing_loss_func(
    probs: torch.Tensor,
    routing_map: torch.Tensor,
    batch_size: int,
    seq_length: int,
    topk: int,
    moe_aux_loss_coeff: float,
    sequence_partition_group=None,
    moe_device_level_aux_loss_coeff: float=None,
    moe_comm_aux_loss_coeff: float=None,
    moe_router_topk_limited_devices: float=None,
    moe_complementary_seq_aux_loss: bool=False,
):
    """
    Calculate the auxiliary loss in sequence-level by computing the loss for each individual sample.
    Refer to the DeepSeek-V2 huggingface repo
    (https://huggingface.co/deepseek-ai/DeepSeek-V2) for details.

    Args:
        probs (torch.Tensor): Softmax probabilities output by the router for each token.
                              Shape in [num_tokens, num_experts].
        routing_map (torch.Tensor): Mapping of tokens to experts assignment.
                                    Shape in [num_tokens, num_experts].
        batch_size (int): Batch size to process.
        seq_length (int): Sequence length to process.
        topk (int): Number of experts to route to for each token.
        moe_aux_loss_coeff (float): Scaling coefficient for the auxiliary loss.
        sequence_partition_group (optional): The parallel group over which the sequence is
                                             partitioned. If None, no partitioning is applied.
                                             Defaults to None.

    Returns:
        torch.Tensor: The sequence auxiliary loss for load balancing.
    """
    num_sub_sequence = 1
    num_experts = probs.shape[1]

    probs_for_aux_loss = probs.view(seq_length, batch_size, -1)
    routing_map = routing_map.view(seq_length, batch_size, -1)

    # If the sequence is partitioned by certain parallelism strategies like Sequence Parallelism
    # or Context Parallelism, compute the gradient of the auxiliary loss with respect to the full
    # sequence.
    if sequence_partition_group is not None:
        num_sub_sequence = torch.distributed.get_world_size(sequence_partition_group)
        seq_length *= num_sub_sequence
        probs_for_aux_loss = gather_from_sequence_parallel_region(
            probs_for_aux_loss, group=sequence_partition_group
        )

    cost_coeff = routing_map.sum(dim=0, dtype=torch.float).div_(seq_length * topk / num_experts)
    if moe_complementary_seq_aux_loss:
        assert (
            (moe_device_level_aux_loss_coeff is None) and 
            (moe_comm_aux_loss_coeff is None)
            ), "moe_complementary_seq_aux_loss only used in deepseekV3, which means no other aux loss used"
        sum_value = probs_for_aux_loss.sum(dim=-1, keepdim=True)
        probs_for_aux_loss = probs_for_aux_loss / (sum_value + 1e-20)
    seq_aux_loss = (cost_coeff * probs_for_aux_loss.mean(dim=0)).sum(dim=1).mean()
    seq_aux_loss *= moe_aux_loss_coeff

    if moe_device_level_aux_loss_coeff is not None:
        num_group = (
        parallel_state.get_expert_model_parallel_world_size()
        )  # num_group equals to expert parallel size
        device_aux_loss = (cost_coeff.view(batch_size, num_group, -1).mean(dim=2) * 
                           probs_for_aux_loss.mean(dim=0).view(batch_size, num_group, -1).sum(dim=2)).sum(dim=1).mean()
        device_aux_loss *= moe_device_level_aux_loss_coeff
        seq_aux_loss += device_aux_loss
    if moe_comm_aux_loss_coeff is not None:
        num_group = (
        parallel_state.get_expert_model_parallel_world_size()
        )  # num_group equals to expert parallel size
        cost_coeff = routing_map.view(seq_length, batch_size, num_group, -1).any(dim=3).sum(dim=0).float()
        cost_coeff.div_(seq_length *  moe_router_topk_limited_devices / num_group)
        comm_aux_loss = (cost_coeff * 
                           probs_for_aux_loss.mean(dim=0).view(batch_size, num_group, -1).sum(dim=2)).sum(dim=1).mean()
        comm_aux_loss *= moe_comm_aux_loss_coeff
        seq_aux_loss += comm_aux_loss
        
    return seq_aux_loss

def topk_softmax_with_capacity(
    logits: torch.Tensor,
    topk: int,
    capacity_factor: Optional[float] = None,
    pad_to_capacity: bool = False,
    drop_policy: str = "probs",
    use_pre_softmax: bool = False,
    num_groups: Optional[int] = None,
    group_topk: Optional[int] = None,
    scaling_factor: Optional[float] = None,
    deterministic_mode: bool = False,
    score_function: str = "softmax",
    expert_bias: Optional[torch.Tensor] = None,
    device_level_capacity: bool = False,
):
    """Apply capacity and padding to the top-k selection.
    Args:
        logits (torch.Tensor): Logits tensor.
        topk (int): The number of experts to select for each token.
        capacity_factor (float): The capacity factor of each expert. Will drop tokens if the number
                               of tokens exceeds the capacity.
        pad_to_capacity (bool): Whether to need padding in token drop mode. The probs for padded
                               tokens will be 0.
        drop_policy (str): The policy to drop tokens. Can be either "prob" or "position".
                           If "prob", the tokens with the lowest probabilities will be dropped.
                           If "position", tokens at the end of each batch will be dropped.
        use_pre_softmax (bool): Whether to apply softmax before top-k selection.
        num_groups (int): Number of groups for routed experts.
        group_topk (int): Number of selected groups for each token.
        scaling_factor (float): Scaling factor of routing score in top-k selection.
        deterministic_mode (bool): Deprecated.
        score_function (str): The score function to use. Can be either "softmax" or "sigmoid".
        expert_bias (torch.Tensor): The bias added to logits for expert routing.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            - routing_probs (torch.Tensor): A tensor of shape [num_tokens, num_experts] containing
              the routing probabilities for each token to each expert.
            - routing_map (torch.Tensor): A mask tensor of shape [num_tokens, num_experts]
              indicating which experts were selected for each token. True values represent
              the selected experts.
            - tokens_per_expert (torch.Tensor): A tensor of shape [num_experts] containing
              the number of local tokens assigned to each expert before dropping and padding.
    """
    assert logits.dim() == 2, f"Expected 2D logits [num_tokens, num_experts], got {logits.dim()}."
    num_tokens, num_experts = logits.shape

    def compute_topk(scores, topk, num_groups=None, group_topk=None):
        if group_topk:
            return group_limited_topk(
                scores=scores,
                topk=topk,
                num_tokens=num_tokens,
                num_experts=num_experts,
                num_groups=num_groups,
                group_topk=group_topk,
            )
        else:
            return torch.topk(scores, k=topk, dim=1)

    if score_function == "softmax":
        if use_pre_softmax:
            scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits)
            probs, top_indices = compute_topk(scores, topk, num_groups, group_topk)
        else:
            scores, top_indices = compute_topk(logits, topk, num_groups, group_topk)
            probs = torch.softmax(scores, dim=-1, dtype=torch.float32).type_as(logits)
    elif score_function == "sigmoid":
        scores = torch.sigmoid(logits)
        if expert_bias is not None:
            scores_for_routing = scores + expert_bias
            _, top_indices = compute_topk(scores_for_routing, topk, num_groups, group_topk)
            scores = torch.gather(scores, dim=1, index=top_indices).type_as(logits)
        else:
            scores, top_indices = compute_topk(scores, topk, num_groups, group_topk)
        probs = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores
    else:
        raise ValueError(f"Invalid score_function: {score_function}")

    if scaling_factor:
        probs = probs * scaling_factor

    # TODO Try using element-wise operations instead of scatter?
    topk_masked_gates = torch.zeros_like(logits).scatter(1, top_indices, probs)
    topk_map = torch.zeros_like(logits).int().scatter(1, top_indices, 1).bool()
    tokens_per_expert = topk_map.sum(dim=0)

    if capacity_factor is None:
        # TopK without capacity
        return topk_masked_gates, topk_map, tokens_per_expert
    elif device_level_capacity:
        assert drop_policy=='probs', f"only support 'probs' for device_level capacity, but get {drop_policy}"
        num_group = (
        parallel_state.get_expert_model_parallel_world_size()
        )  # num_group equals to expert parallel size
        device_expert_capacity = get_capacity(
            num_tokens=num_tokens * topk, num_experts=num_experts, capacity_factor=capacity_factor
        )*num_experts//num_group
        # Maskout exceeded tokens
        if drop_policy == "probs":
            topk_masked_group_gates = topk_masked_gates.view(num_tokens, num_group, -1)
            topk_masked_group_gates = topk_masked_group_gates.permute(0,2,1).reshape(-1, num_group)
            _, capacity_indices = torch.topk(
                topk_masked_group_gates, k=device_expert_capacity, dim=0, sorted=False
            )
            capacity_mask = torch.zeros([num_tokens*num_experts//num_group, num_group], device=logits.device).scatter(0, capacity_indices, 1).bool()
            capacity_mask = capacity_mask.view(num_tokens, num_experts//num_group, num_group).permute(0,2,1).reshape(num_tokens, -1)
        else:
            raise ValueError(f"Invalid drop_policy: {drop_policy}")

        if pad_to_capacity:
            final_map = capacity_mask
            final_probs = topk_masked_gates * final_map
        else:
            # Get exceed mask and maskout exceeded probs and indices
            final_map = torch.logical_and(topk_map, capacity_mask)
            final_probs = topk_masked_gates * final_map
        return final_probs, final_map, tokens_per_expert
    else:
        # TopK with capacity
        expert_capacity = get_capacity(
            num_tokens=num_tokens * topk, num_experts=num_experts, capacity_factor=capacity_factor
        )

        # Maskout exceeded tokens
        if drop_policy == "probs":
            _, capacity_indices = torch.topk(
                topk_masked_gates, k=expert_capacity, dim=0, sorted=False
            )
            capacity_mask = torch.zeros_like(logits).scatter(0, capacity_indices, 1).bool()
        elif drop_policy == "position":
            _, capacity_indices = torch.topk(topk_map.int(), k=expert_capacity, dim=0, sorted=False)
            capacity_mask = torch.zeros_like(logits).scatter(0, capacity_indices, 1).bool()
        else:
            raise ValueError(f"Invalid drop_policy: {drop_policy}")

        if pad_to_capacity:
            final_map = capacity_mask
            final_probs = topk_masked_gates * final_map
        else:
            # Get exceed mask and maskout exceeded probs and indices
            final_map = torch.logical_and(topk_map, capacity_mask)
            final_probs = topk_masked_gates * final_map
        return final_probs, final_map, tokens_per_expert


megatron.core.transformer.moe.moe_utils.sequence_load_balancing_loss_func = sequence_load_balancing_loss_func
megatron.core.transformer.moe.moe_utils.topk_softmax_with_capacity = topk_softmax_with_capacity